about | summary | refs | log | tree | commit | diff | homepage
path: root/tensorflow/core/common_runtime/ring_reducer.h
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/common_runtime/ring_reducer.h')
-rw-r--r-- tensorflow/core/common_runtime/ring_reducer.h 55
1 files changed, 27 insertions, 28 deletions
diff --git a/tensorflow/core/common_runtime/ring_reducer.h b/tensorflow/core/common_runtime/ring_reducer.h
index 3e1988e787..0848e37b52 100644
--- a/tensorflow/core/common_runtime/ring_reducer.h
+++ b/tensorflow/core/common_runtime/ring_reducer.h
@@ -16,25 +16,35 @@ limitations under the License.
#define TENSORFLOW_CORE_COMMON_RUNTIME_RING_REDUCER_H_
#include <deque>
+#include <memory>
+#include <string>
+#include <vector>
#include "tensorflow/core/common_runtime/base_collective_executor.h"
#include "tensorflow/core/framework/collective.h"
-#include "tensorflow/core/framework/device_attributes.pb.h"
namespace tensorflow {
-class DeviceMgr;
+class Device;
// Ring-algorithm implementation of collective all-reduce.
-class RingReducer {
+class RingReducer : public CollectiveImplementationInterface {
public:
- RingReducer(CollectiveExecutor* col_exec, const DeviceMgr* dev_mgr,
- OpKernelContext* ctx, OpKernelContext::Params* op_params,
- const CollectiveParams& col_params, const string& exec_key,
- int64 step_id, const Tensor* input, Tensor* output);
+ RingReducer();
+ ~RingReducer() override;
- virtual ~RingReducer();
+ // Establishes the requested number of subdivision permutations based on the
+ // ring order implicit in the device order.
+ Status InitializeCollectiveParams(CollectiveParams* col_params) override;
- void Run(StatusCallback done);
+ // Initializes members of CollectiveContext not yet initialized, i.e. device
+ // and device_locality. Also saves the CollectiveContext in this object.
+ Status InitializeCollectiveContext(CollectiveContext* col_ctx) override;
+
+ // Begins async execution of the ring reduce algorithm.
+ // Must be called in a blockable thread.
+ // TODO(b/80529858): remove the previous warning when we have a dedicated
+ // collective threadpool.
+ void Run(StatusCallback done) override;
private:
// Called when a bad status is received that implies we should terminate
@@ -101,7 +111,7 @@ class RingReducer {
// For constructing log messages for debugging.
string FieldState();
- string TensorDebugString(Tensor tensor);
+ string TensorDebugString(const Tensor& tensor);
// Producer/Consumer Queue of RingField structs.
class PCQueue {
@@ -116,30 +126,19 @@ class RingReducer {
std::deque<RingField*> deque_ GUARDED_BY(pcq_mu_);
};
- CollectiveExecutor* col_exec_; // Not owned
- const DeviceMgr* dev_mgr_; // Not owned
- OpKernelContext* ctx_; // Not owned
- OpKernelContext::Params* op_params_; // Not owned
- const CollectiveParams& col_params_;
- const string exec_key_;
- const Tensor* input_; // Not owned
- Tensor* output_; // Not owned
- const int rank_;
- const int64 step_id_;
- const int group_size_;
- const int num_subdivs_;
+ CollectiveContext* col_ctx_; // Not owned
+ const CollectiveParams* col_params_; // Not owned
+ StatusCallback done_;
+ int group_size_;
+ int num_subdivs_;
Tensor group_size_tensor_;
Notification group_size_tensor_ready_;
std::unique_ptr<CollectiveAdapter> ca_;
- StatusCallback done_;
- Device* device_; // The device for which this instance labors
- const string device_name_;
- DeviceLocality device_locality_;
-
mutex status_mu_;
Status status_ GUARDED_BY(status_mu_);
-
std::vector<RingField> rfv_;
+
+ friend class RingReducerTest;
};
} // namespace tensorflow