diff options
author | 2018-08-27 14:19:20 -0700 | |
---|---|---|
committer | 2018-08-27 14:23:24 -0700 | |
commit | 85a6164912e21bc398b930943da7ea90ffe3bc20 (patch) | |
tree | af2efcf298518583c03dc2d7d415cd72df1d60b1 /tensorflow/core/framework | |
parent | 59f3c57182fac4d745bb01f3976bb9832c06333d (diff) |
Refactor collectives to colocate implementation-specific code.
Before this change, introducing a new collective algorithm required touching
multiple files. CollectiveParams setup was in common_runtime/collective_param_resolver_local,
and the data movement was in common_runtime/reducer and common_runtime/broadcaster.
This change introduces CollectiveImplementationInterface, which brings
together param initialization and data movement for a collective algorithm.
Every collective implementation implements this interface and overrides its
virtual methods. This should reduce obscurity and lead to code with fewer
dependencies.
PiperOrigin-RevId: 210430157
Diffstat (limited to 'tensorflow/core/framework')
-rw-r--r-- | tensorflow/core/framework/collective.cc | 102 | ||||
-rw-r--r-- | tensorflow/core/framework/collective.h | 113 |
2 files changed, 211 insertions, 4 deletions
diff --git a/tensorflow/core/framework/collective.cc b/tensorflow/core/framework/collective.cc index d4ac50cbbe..4cb277d5a8 100644 --- a/tensorflow/core/framework/collective.cc +++ b/tensorflow/core/framework/collective.cc @@ -21,6 +21,31 @@ limitations under the License. namespace tensorflow { +namespace { +// A RegistrationInfo object stores a collective implementation's registration +// details. `factory` is used to create instances of the collective +// implementation. +struct RegistrationInfo { + // This constructor also creates, and stores in `param_resolver_instance`, + // what is effectively a static instance of the collective implementation. + // During param resolution of collective ops we return this static instance. + // The actual op execution gets a fresh instance using `factory`. + RegistrationInfo(const string& n, CollectiveRegistry::Factory f) + : name(n), + factory(std::move(f)), + param_resolver_instance(this->factory()) {} + string name; + CollectiveRegistry::Factory factory; + CollectiveImplementationInterface* param_resolver_instance; +}; + +std::vector<RegistrationInfo>* MutableCollectiveRegistry() { + static std::vector<RegistrationInfo>* registry = + new std::vector<RegistrationInfo>; + return registry; +} +} // namespace + string CollGroupParams::ToString() const { return strings::StrCat("CollGroupParams {group_key=", group_key, " group_size=", group_size, @@ -102,7 +127,8 @@ string CollectiveParams::ToString() const { strings::StrAppend(&v, " ", instance.ToString()); strings::StrAppend(&v, " ", task.ToString()); strings::StrAppend(&v, " default_rank=", default_rank, - " is_source=", is_source, " subdiv_rank={"); + " is_source=", is_source, " source_rank=", source_rank, + " subdiv_rank={"); for (const auto& r : subdiv_rank) { strings::StrAppend(&v, r, ","); } @@ -115,7 +141,81 @@ string CollectiveParams::ToString() const { return ctx->params_; } +CollectiveContext::CollectiveContext(CollectiveExecutor* col_exec, + const DeviceMgr* dev_mgr, + 
OpKernelContext* ctx, + OpKernelContext::Params* op_params, + const CollectiveParams& col_params, + const string& exec_key, int64 step_id, + const Tensor* input, Tensor* output) + : col_exec(col_exec), + dev_mgr(dev_mgr), + op_ctx(ctx), + op_params(op_params), + col_params(col_params), + exec_key(exec_key), + step_id(step_id), + input(input), + output(output), + device(nullptr), + device_name(col_params.instance.device_names[col_params.default_rank]) {} + /*static*/ int64 CollectiveExecutor::kInvalidId = -1; +/*static*/ +Status CollectiveRegistry::Lookup( + const string& collective_name, + CollectiveImplementationInterface** implementation) { + return LookupHelper(collective_name, implementation, false); +} + +/*static*/ +Status CollectiveRegistry::LookupParamResolverInstance( + const string& collective_name, + CollectiveImplementationInterface** implementation) { + return LookupHelper(collective_name, implementation, true); +} + +/*static*/ +void CollectiveRegistry::GetAll( + std::vector<CollectiveImplementationInterface*>* implementations) { + std::vector<RegistrationInfo>* registry = MutableCollectiveRegistry(); + for (const RegistrationInfo& reg_info : *registry) + implementations->emplace_back(reg_info.factory()); +} + +/*static*/ +Status CollectiveRegistry::Register(const string& collective_name, + Factory factory) { + std::vector<RegistrationInfo>* registry = MutableCollectiveRegistry(); + for (const RegistrationInfo& reg_info : *registry) { + if (reg_info.name == collective_name) + return errors::Internal("Already registered collective ", + collective_name); + } + registry->emplace_back(collective_name, std::move(factory)); + return Status::OK(); +} + +/*static*/ +Status CollectiveRegistry::LookupHelper( + const string& collective_name, + CollectiveImplementationInterface** implementation, bool param_resolver) { + std::vector<RegistrationInfo>* registry = MutableCollectiveRegistry(); + for (const RegistrationInfo& reg_info : *registry) { + if (reg_info.name 
== collective_name) { + if (param_resolver) { + *implementation = reg_info.param_resolver_instance; + } else { + *implementation = reg_info.factory(); + } + return Status::OK(); + } + } + return errors::Internal( + "CollectiveRegistry::Lookup did not find collective implementation ", + collective_name); +} + } // namespace tensorflow diff --git a/tensorflow/core/framework/collective.h b/tensorflow/core/framework/collective.h index 0b37b3a88c..e35edb09d0 100644 --- a/tensorflow/core/framework/collective.h +++ b/tensorflow/core/framework/collective.h @@ -18,6 +18,7 @@ limitations under the License. #include <string> #include <vector> +#include "tensorflow/core/framework/device_attributes.pb.h" #include "tensorflow/core/framework/device_base.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/lib/core/refcount.h" @@ -30,7 +31,8 @@ class CompleteGroupRequest; class CompleteGroupResponse; class CompleteInstanceRequest; class CompleteInstanceResponse; -class DeviceLocality; +class Device; +class DeviceMgr; class GetStepSequenceRequest; class GetStepSequenceResponse; class Op; @@ -64,10 +66,10 @@ struct CollGroupParams { // interpretation. On first execution the runtime will update this // structure with decisions that will guide all subsequent executions. struct CollImplDetails { + string collective_name; std::vector<std::vector<int>> subdiv_permutations; std::vector<int> subdiv_offsets; - // broadcast only: rank of source in each subdiv - std::vector<int> subdiv_source_rank; + std::vector<int> subdiv_source_rank; // rank of source in each subdiv }; // Data common to all members of a collective instance. @@ -104,6 +106,7 @@ struct CollectiveParams { string name = ""; // node name used only for log or error messages int default_rank = -1; // index of this op within device_names bool is_source = false; // broadcast only + int source_rank = -1; // broadcast only // Rank of this device in each subdivision permutation. 
std::vector<int> subdiv_rank; std::unique_ptr<OpKernel> merge_op; // reduction only @@ -306,6 +309,110 @@ class PerStepCollectiveRemoteAccess : public CollectiveRemoteAccess { virtual void StartAbort(const Status& s) = 0; }; +class CollectiveContext { + public: + CollectiveContext(CollectiveExecutor* col_exec, const DeviceMgr* dev_mgr, + OpKernelContext* ctx, OpKernelContext::Params* op_params, + const CollectiveParams& col_params, const string& exec_key, + int64 step_id, const Tensor* input, Tensor* output); + + virtual ~CollectiveContext() = default; + + CollectiveExecutor* col_exec; // Not owned + const DeviceMgr* dev_mgr; // Not owned + OpKernelContext* op_ctx; // Not owned + OpKernelContext::Params* op_params; // Not owned + const CollectiveParams& col_params; + const string exec_key; + const int64 step_id; + const Tensor* input; // Not owned + Tensor* output; // Not owned + Device* device; // The device for which this instance labors + const string device_name; + DeviceLocality device_locality; +}; + +// Interface of a Collective Op implementation. Each specific CollectiveOp will +// implement this interface and register the implementation via the +// CollectiveRegistry detailed below. See common_runtime/ring_reducer and +// common_runtime/hierarchical_tree_broadcaster for examples. +class CollectiveImplementationInterface { + public: + virtual ~CollectiveImplementationInterface() = default; + + // Initializes the portions of `col_params` specific to this + // implementation. Called exactly once for every Collective instance during + // the CollectiveParams resolution process when the graph is first executed. + // NOTE(ayushd): This is effectively a static function because it modifies the + // `col_params` passed in and should not manipulate any data members. However + // because it is virtual and needs to be implemented by every derived class we + // do not mark it as static. 
+ virtual Status InitializeCollectiveParams(CollectiveParams* col_params) = 0; + + // Prepares the CollectiveContext for executing this CollectiveImplementation. + // Called from CollectiveExecutor right before calling Run(). The + // CollectiveContext passed in must outlive the CollectiveImplementation + // object. + virtual Status InitializeCollectiveContext(CollectiveContext* col_ctx) = 0; + + // Processes and moves data according to the logic of this Collective + // implementation. Relies on appropriate initialization of op-specific + // CollectiveParams in InitializeCollectiveParams(), as well as appropriate + // context initialization in InitializeCollectiveContext(). + virtual void Run(StatusCallback done) = 0; +}; + +// Static-methods only class for registering and looking up collective +// implementations. +class CollectiveRegistry { + public: + using Factory = std::function<CollectiveImplementationInterface*()>; + // Looks up a previously registered CollectiveImplementation under + // `collective_name`. If found, creates an instance of the implementation and + // assigns it to `implementation`. + static Status Lookup(const string& collective_name, + CollectiveImplementationInterface** implementation); + + // Looks up a previously registered CollectiveImplementation under + // `collective_name`. If found, returns the static instance of this + // implementation via `implementation`. This instance should only be used to + // call InitializeCollectiveParams. + static Status LookupParamResolverInstance( + const string& collective_name, + CollectiveImplementationInterface** implementation); + + // Returns all registered collective implementations. + static void GetAll( + std::vector<CollectiveImplementationInterface*>* implementations); + + private: + friend class CollectiveRegistration; + // Registers a CollectiveImplementation with name `collective_name` and + // factory `factory`. 
The latter is a function used to create instances of + // the CollectiveImplementation. Also creates a static instance of the + // implementation - this instance is used during param resolution and should + // only be used to call InitializeCollectiveParams. + static Status Register(const string& collective_name, Factory factory); + + static Status LookupHelper(const string& collective_name, + CollectiveImplementationInterface** implementation, + bool param_resolver); +}; + +// Class used to call CollectiveRegistry::Register. This should only be used to +// create a global static object. +class CollectiveRegistration { + public: + CollectiveRegistration(const string& collective_name, + CollectiveRegistry::Factory factory) { + TF_CHECK_OK(CollectiveRegistry::Register(collective_name, factory)); + } +}; + +#define REGISTER_COLLECTIVE(name, implementation) \ + static CollectiveRegistration register_##name##_collective( \ + #name, []() { return new implementation; }); + } // namespace tensorflow #endif // TENSORFLOW_CORE_FRAMEWORK_COLLECTIVE_H_ |