diff options
Diffstat (limited to 'tensorflow/core/common_runtime/ring_reducer.h')
-rw-r--r-- | tensorflow/core/common_runtime/ring_reducer.h | 55 |
1 files changed, 27 insertions, 28 deletions
diff --git a/tensorflow/core/common_runtime/ring_reducer.h b/tensorflow/core/common_runtime/ring_reducer.h index 3e1988e787..0848e37b52 100644 --- a/tensorflow/core/common_runtime/ring_reducer.h +++ b/tensorflow/core/common_runtime/ring_reducer.h @@ -16,25 +16,35 @@ limitations under the License. #define TENSORFLOW_CORE_COMMON_RUNTIME_RING_REDUCER_H_ #include <deque> +#include <memory> +#include <string> +#include <vector> #include "tensorflow/core/common_runtime/base_collective_executor.h" #include "tensorflow/core/framework/collective.h" -#include "tensorflow/core/framework/device_attributes.pb.h" namespace tensorflow { -class DeviceMgr; +class Device; // Ring-algorithm implementation of collective all-reduce. -class RingReducer { +class RingReducer : public CollectiveImplementationInterface { public: - RingReducer(CollectiveExecutor* col_exec, const DeviceMgr* dev_mgr, - OpKernelContext* ctx, OpKernelContext::Params* op_params, - const CollectiveParams& col_params, const string& exec_key, - int64 step_id, const Tensor* input, Tensor* output); + RingReducer(); + ~RingReducer() override; - virtual ~RingReducer(); + // Establishes the requested number of subdivision permutations based on the + // ring order implicit in the device order. + Status InitializeCollectiveParams(CollectiveParams* col_params) override; - void Run(StatusCallback done); + // Initializes members of CollectiveContext not yet initialized, i.e. device + // and device_locality. Also saves the CollectiveContext in this object. + Status InitializeCollectiveContext(CollectiveContext* col_ctx) override; + + // Begins async execution of the ring reduce algorithm. + // Must be called in a blockable thread. + // TODO(b/80529858): remove the previous warning when we have a dedicated + // collective threadpool. + void Run(StatusCallback done) override; private: // Called when a bad status is received that implies we should terminate @@ -101,7 +111,7 @@ class RingReducer { // For constructing log messages for debugging. string FieldState(); - string TensorDebugString(Tensor tensor); + string TensorDebugString(const Tensor& tensor); // Producer/Consumer Queue of RingField structs. class PCQueue { @@ -116,30 +126,19 @@ class RingReducer { std::deque<RingField*> deque_ GUARDED_BY(pcq_mu_); }; - CollectiveExecutor* col_exec_; // Not owned - const DeviceMgr* dev_mgr_; // Not owned - OpKernelContext* ctx_; // Not owned - OpKernelContext::Params* op_params_; // Not owned - const CollectiveParams& col_params_; - const string exec_key_; - const Tensor* input_; // Not owned - Tensor* output_; // Not owned - const int rank_; - const int64 step_id_; - const int group_size_; - const int num_subdivs_; + CollectiveContext* col_ctx_; // Not owned + const CollectiveParams* col_params_; // Not owned + StatusCallback done_; + int group_size_; + int num_subdivs_; Tensor group_size_tensor_; Notification group_size_tensor_ready_; std::unique_ptr<CollectiveAdapter> ca_; - StatusCallback done_; - Device* device_; // The device for which this instance labors - const string device_name_; - DeviceLocality device_locality_; - mutex status_mu_; Status status_ GUARDED_BY(status_mu_); - std::vector<RingField> rfv_; + + friend class RingReducerTest; }; } // namespace tensorflow |