path: root/tensorflow/core/framework
author     Ayush Dubey <ayushd@google.com>                    2018-08-27 14:19:20 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>    2018-08-27 14:23:24 -0700
commit     85a6164912e21bc398b930943da7ea90ffe3bc20 (patch)
tree       af2efcf298518583c03dc2d7d415cd72df1d60b1 /tensorflow/core/framework
parent     59f3c57182fac4d745bb01f3976bb9832c06333d (diff)
Refactor collectives to colocate implementation-specific code.
Before this change, introducing a new collective algorithm required touching multiple files: CollectiveParams setup lived in common_runtime/collective_param_resolver_local, and the data movement lived in common_runtime/reducer and common_runtime/broadcaster. This change introduces CollectiveImplementationInterface, which brings together param initialization and data movement for a collective algorithm. Every collective implementation implements this interface and overrides its virtual methods. This should reduce obscurity and lead to code with fewer dependencies.

PiperOrigin-RevId: 210430157
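To illustrate the shape of the new extension point, here is a minimal sketch of a collective plugging into CollectiveImplementationInterface and registering itself with REGISTER_COLLECTIVE. It is not part of this commit: the MyAllGather name and its method bodies are hypothetical, and the col_params->instance.impl_details field access is assumed from context.

#include "tensorflow/core/framework/collective.h"

namespace tensorflow {

// Hypothetical example; illustrates the interface introduced by this change.
class MyAllGather : public CollectiveImplementationInterface {
 public:
  // Fills in the implementation-specific fields of `col_params` during param
  // resolution; called on the registry's static instance.
  Status InitializeCollectiveParams(CollectiveParams* col_params) override {
    // Assumes CollInstanceParams holds a CollImplDetails named `impl_details`.
    col_params->instance.impl_details.collective_name = "MyAllGather";
    return Status::OK();
  }

  // Remembers the per-execution context; `col_ctx` outlives this object.
  Status InitializeCollectiveContext(CollectiveContext* col_ctx) override {
    col_ctx_ = col_ctx;
    return Status::OK();
  }

  // A real implementation would move data between devices here; it must
  // invoke `done` exactly once with the final status.
  void Run(StatusCallback done) override { done(Status::OK()); }

 private:
  CollectiveContext* col_ctx_ = nullptr;  // Not owned.
};

REGISTER_COLLECTIVE(MyAllGather, MyAllGather);

}  // namespace tensorflow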
Diffstat (limited to 'tensorflow/core/framework')
-rw-r--r--  tensorflow/core/framework/collective.cc  102
-rw-r--r--  tensorflow/core/framework/collective.h   113
2 files changed, 211 insertions, 4 deletions
diff --git a/tensorflow/core/framework/collective.cc b/tensorflow/core/framework/collective.cc
index d4ac50cbbe..4cb277d5a8 100644
--- a/tensorflow/core/framework/collective.cc
+++ b/tensorflow/core/framework/collective.cc
@@ -21,6 +21,31 @@ limitations under the License.
namespace tensorflow {
+namespace {
+// A RegistrationInfo object stores a collective implementation's registration
+// details. `factory` is used to create instances of the collective
+// implementation.
+struct RegistrationInfo {
+ // This constructor also creates, and stores in `param_resolver_instance`,
+ // what is effectively a static instance of the collective implementation.
+ // During param resolution of collective ops we return this static instance.
+ // The actual op execution gets a fresh instance using `factory`.
+ RegistrationInfo(const string& n, CollectiveRegistry::Factory f)
+ : name(n),
+ factory(std::move(f)),
+ param_resolver_instance(this->factory()) {}
+ string name;
+ CollectiveRegistry::Factory factory;
+ CollectiveImplementationInterface* param_resolver_instance;
+};
+
+std::vector<RegistrationInfo>* MutableCollectiveRegistry() {
+ static std::vector<RegistrationInfo>* registry =
+ new std::vector<RegistrationInfo>;
+ return registry;
+}
+} // namespace
+
string CollGroupParams::ToString() const {
return strings::StrCat("CollGroupParams {group_key=", group_key,
" group_size=", group_size,
@@ -102,7 +127,8 @@ string CollectiveParams::ToString() const {
strings::StrAppend(&v, " ", instance.ToString());
strings::StrAppend(&v, " ", task.ToString());
strings::StrAppend(&v, " default_rank=", default_rank,
- " is_source=", is_source, " subdiv_rank={");
+ " is_source=", is_source, " source_rank=", source_rank,
+ " subdiv_rank={");
for (const auto& r : subdiv_rank) {
strings::StrAppend(&v, r, ",");
}
@@ -115,7 +141,81 @@ string CollectiveParams::ToString() const {
return ctx->params_;
}
+CollectiveContext::CollectiveContext(CollectiveExecutor* col_exec,
+ const DeviceMgr* dev_mgr,
+ OpKernelContext* ctx,
+ OpKernelContext::Params* op_params,
+ const CollectiveParams& col_params,
+ const string& exec_key, int64 step_id,
+ const Tensor* input, Tensor* output)
+ : col_exec(col_exec),
+ dev_mgr(dev_mgr),
+ op_ctx(ctx),
+ op_params(op_params),
+ col_params(col_params),
+ exec_key(exec_key),
+ step_id(step_id),
+ input(input),
+ output(output),
+ device(nullptr),
+ device_name(col_params.instance.device_names[col_params.default_rank]) {}
+
/*static*/
int64 CollectiveExecutor::kInvalidId = -1;
+/*static*/
+Status CollectiveRegistry::Lookup(
+ const string& collective_name,
+ CollectiveImplementationInterface** implementation) {
+ return LookupHelper(collective_name, implementation, false);
+}
+
+/*static*/
+Status CollectiveRegistry::LookupParamResolverInstance(
+ const string& collective_name,
+ CollectiveImplementationInterface** implementation) {
+ return LookupHelper(collective_name, implementation, true);
+}
+
+/*static*/
+void CollectiveRegistry::GetAll(
+ std::vector<CollectiveImplementationInterface*>* implementations) {
+ std::vector<RegistrationInfo>* registry = MutableCollectiveRegistry();
+ for (const RegistrationInfo& reg_info : *registry)
+ implementations->emplace_back(reg_info.factory());
+}
+
+/*static*/
+Status CollectiveRegistry::Register(const string& collective_name,
+ Factory factory) {
+ std::vector<RegistrationInfo>* registry = MutableCollectiveRegistry();
+ for (const RegistrationInfo& reg_info : *registry) {
+ if (reg_info.name == collective_name)
+ return errors::Internal("Already registered collective ",
+ collective_name);
+ }
+ registry->emplace_back(collective_name, std::move(factory));
+ return Status::OK();
+}
+
+/*static*/
+Status CollectiveRegistry::LookupHelper(
+ const string& collective_name,
+ CollectiveImplementationInterface** implementation, bool param_resolver) {
+ std::vector<RegistrationInfo>* registry = MutableCollectiveRegistry();
+ for (const RegistrationInfo& reg_info : *registry) {
+ if (reg_info.name == collective_name) {
+ if (param_resolver) {
+ *implementation = reg_info.param_resolver_instance;
+ } else {
+ *implementation = reg_info.factory();
+ }
+ return Status::OK();
+ }
+ }
+ return errors::Internal(
+ "CollectiveRegistry::Lookup did not find collective implementation ",
+ collective_name);
+}
+
} // namespace tensorflow
diff --git a/tensorflow/core/framework/collective.h b/tensorflow/core/framework/collective.h
index 0b37b3a88c..e35edb09d0 100644
--- a/tensorflow/core/framework/collective.h
+++ b/tensorflow/core/framework/collective.h
@@ -18,6 +18,7 @@ limitations under the License.
#include <string>
#include <vector>
+#include "tensorflow/core/framework/device_attributes.pb.h"
#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/lib/core/refcount.h"
@@ -30,7 +31,8 @@ class CompleteGroupRequest;
class CompleteGroupResponse;
class CompleteInstanceRequest;
class CompleteInstanceResponse;
-class DeviceLocality;
+class Device;
+class DeviceMgr;
class GetStepSequenceRequest;
class GetStepSequenceResponse;
class Op;
@@ -64,10 +66,10 @@ struct CollGroupParams {
// interpretation. On first execution the runtime will update this
// structure with decisions that will guide all subsequent executions.
struct CollImplDetails {
+ string collective_name;
std::vector<std::vector<int>> subdiv_permutations;
std::vector<int> subdiv_offsets;
- // broadcast only: rank of source in each subdiv
- std::vector<int> subdiv_source_rank;
+ std::vector<int> subdiv_source_rank; // rank of source in each subdiv
};
// Data common to all members of a collective instance.
@@ -104,6 +106,7 @@ struct CollectiveParams {
string name = ""; // node name used only for log or error messages
int default_rank = -1; // index of this op within device_names
bool is_source = false; // broadcast only
+ int source_rank = -1; // broadcast only
// Rank of this device in each subdivision permutation.
std::vector<int> subdiv_rank;
std::unique_ptr<OpKernel> merge_op; // reduction only
@@ -306,6 +309,110 @@ class PerStepCollectiveRemoteAccess : public CollectiveRemoteAccess {
virtual void StartAbort(const Status& s) = 0;
};
+class CollectiveContext {
+ public:
+ CollectiveContext(CollectiveExecutor* col_exec, const DeviceMgr* dev_mgr,
+ OpKernelContext* ctx, OpKernelContext::Params* op_params,
+ const CollectiveParams& col_params, const string& exec_key,
+ int64 step_id, const Tensor* input, Tensor* output);
+
+ virtual ~CollectiveContext() = default;
+
+ CollectiveExecutor* col_exec; // Not owned
+ const DeviceMgr* dev_mgr; // Not owned
+ OpKernelContext* op_ctx; // Not owned
+ OpKernelContext::Params* op_params; // Not owned
+ const CollectiveParams& col_params;
+ const string exec_key;
+ const int64 step_id;
+ const Tensor* input; // Not owned
+ Tensor* output; // Not owned
+ Device* device; // The device for which this instance labors
+ const string device_name;
+ DeviceLocality device_locality;
+};
+
+// Interface of a Collective Op implementation. Each specific CollectiveOp will
+// implement this interface and register the implementation via the
+// CollectiveRegistry detailed below. See common_runtime/ring_reducer and
+// common_runtime/hierarchical_tree_broadcaster for examples.
+class CollectiveImplementationInterface {
+ public:
+ virtual ~CollectiveImplementationInterface() = default;
+
+ // Initializes the portions of `col_params` specific to this
+ // implementation. Called exactly once for every Collective instance during
+ // the CollectiveParams resolution process when the graph is first executed.
+ // NOTE(ayushd): This is effectively a static function because it modifies the
+ // `col_params` passed in and should not manipulate any data members. However,
+ // because it is virtual and needs to be implemented by every derived class, we
+ // do not mark it as static.
+ virtual Status InitializeCollectiveParams(CollectiveParams* col_params) = 0;
+
+ // Prepares the CollectiveContext for executing this CollectiveImplementation.
+ // Called from CollectiveExecutor right before calling Run(). The
+ // CollectiveContext passed in must outlive the CollectiveImplementation
+ // object.
+ virtual Status InitializeCollectiveContext(CollectiveContext* col_ctx) = 0;
+
+ // Processes and moves data according to the logic of this Collective
+ // implementation. Relies on appropriate initialization of op-specific
+ // CollectiveParams in InitializeCollectiveParams(), as well as appropriate
+ // context initialization in InitializeCollectiveContext().
+ virtual void Run(StatusCallback done) = 0;
+};
+
+// Static-methods only class for registering and looking up collective
+// implementations.
+class CollectiveRegistry {
+ public:
+ using Factory = std::function<CollectiveImplementationInterface*()>;
+ // Looks up a previously registered CollectiveImplementation under
+ // `collective_name`. If found, creates an instance of the implementation and
+ // assigns it to `implementation`.
+ static Status Lookup(const string& collective_name,
+ CollectiveImplementationInterface** implementation);
+
+ // Looks up a previously registered CollectiveImplementation under
+ // `collective_name`. If found, returns the static instance of this
+ // implementation via `implementation`. This instance should only be used to
+ // call InitializeCollectiveParams.
+ static Status LookupParamResolverInstance(
+ const string& collective_name,
+ CollectiveImplementationInterface** implementation);
+
+ // Returns all registered collective implementations.
+ static void GetAll(
+ std::vector<CollectiveImplementationInterface*>* implementations);
+
+ private:
+ friend class CollectiveRegistration;
+ // Registers a CollectiveImplementation with name `collective_name` and
+ // factory `factory`. The latter is a function used to create instances of
+ // the CollectiveImplementation. Also creates a static instance of the
+ // implementation - this instance is used during param resolution and should
+ // only be used to call InitializeCollectiveParams.
+ static Status Register(const string& collective_name, Factory factory);
+
+ static Status LookupHelper(const string& collective_name,
+ CollectiveImplementationInterface** implementation,
+ bool param_resolver);
+};
+
+// Class used to call CollectiveRegistry::Register. This should only be used to
+// create a global static object.
+class CollectiveRegistration {
+ public:
+ CollectiveRegistration(const string& collective_name,
+ CollectiveRegistry::Factory factory) {
+ TF_CHECK_OK(CollectiveRegistry::Register(collective_name, factory));
+ }
+};
+
+#define REGISTER_COLLECTIVE(name, implementation) \
+ static CollectiveRegistration register_##name##_collective( \
+ #name, []() { return new implementation; });
+
} // namespace tensorflow
#endif // TENSORFLOW_CORE_FRAMEWORK_COLLECTIVE_H_
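For completeness, a caller-side sketch of how the two registry lookup paths described in collective.h are intended to be used: the registry-owned static instance for param resolution, and a fresh, caller-owned instance for each execution. This is not code from the patch; the ResolveThenRun helper and its flow are illustrative assumptions only.

#include "tensorflow/core/framework/collective.h"
#include "tensorflow/core/lib/core/errors.h"

namespace tensorflow {

// Hypothetical helper; the name ResolveThenRun is illustrative only.
Status ResolveThenRun(const string& collective_name,
                      CollectiveParams* col_params, CollectiveContext* col_ctx,
                      StatusCallback done) {
  // Param resolution: borrow the registry-owned static instance.  Only
  // InitializeCollectiveParams may be called on it; do not delete it.
  CollectiveImplementationInterface* resolver = nullptr;
  TF_RETURN_IF_ERROR(CollectiveRegistry::LookupParamResolverInstance(
      collective_name, &resolver));
  TF_RETURN_IF_ERROR(resolver->InitializeCollectiveParams(col_params));

  // Execution: Lookup() returns a fresh instance from the factory, owned by
  // the caller.  Keep it alive until the collective signals completion.
  CollectiveImplementationInterface* impl = nullptr;
  TF_RETURN_IF_ERROR(CollectiveRegistry::Lookup(collective_name, &impl));
  Status s = impl->InitializeCollectiveContext(col_ctx);
  if (!s.ok()) {
    delete impl;
    return s;
  }
  impl->Run([impl, done](const Status& run_status) {
    delete impl;       // Per-execution instance is caller-owned.
    done(run_status);  // Propagate the collective's final status.
  });
  return Status::OK();
}

}  // namespace tensorflow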