aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/common_runtime/collective_param_resolver_local.h
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/common_runtime/collective_param_resolver_local.h')
-rw-r--r--tensorflow/core/common_runtime/collective_param_resolver_local.h23
1 files changed, 8 insertions, 15 deletions
diff --git a/tensorflow/core/common_runtime/collective_param_resolver_local.h b/tensorflow/core/common_runtime/collective_param_resolver_local.h
index 2e2aa801d9..c5c3497e28 100644
--- a/tensorflow/core/common_runtime/collective_param_resolver_local.h
+++ b/tensorflow/core/common_runtime/collective_param_resolver_local.h
@@ -12,10 +12,14 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_COMMON_RUNTIME_COLLECTIVE_PARAM_RESOLVER_LOCAL_H_
-#define TENSORFLOW_COMMON_RUNTIME_COLLECTIVE_PARAM_RESOLVER_LOCAL_H_
+#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_COLLECTIVE_PARAM_RESOLVER_LOCAL_H_
+#define TENSORFLOW_CORE_COMMON_RUNTIME_COLLECTIVE_PARAM_RESOLVER_LOCAL_H_
+#include <functional>
+#include <memory>
+#include <set>
#include <string>
+#include <vector>
#include "tensorflow/core/framework/collective.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
@@ -79,6 +83,7 @@ class CollectiveParamResolverLocal : public ParamResolverInterface {
// Used to complete/verify CollInstance.
struct InstanceRec;
+
typedef std::function<void(InstanceRec*)> IRConsumer;
struct InstanceRec {
// This structure has two mutexes so that a possibly long
@@ -212,18 +217,6 @@ class CollectiveParamResolverLocal : public ParamResolverInterface {
void CallbackWithStatus(const InstanceRecCallback& done, InstanceRec* irec)
LOCKS_EXCLUDED(irec->out_mu);
- friend class CollectiveParamResolverLocalTest;
- // Establishes the requested number of subdivision permutations based on the
- // ring order implicit in the device order.
- static void GenerateSubdivPerms(const string& device, int source_rank,
- CollectiveParams* cp);
- // Establishes the subdivisions for broadcast op. The first subdiv executes
- // binary tree bcast with one device per task. Each subsequent subdiv
- // executes intra-task binary tree broadcast.
- static void GenerateBcastSubdivPerms(const string& device, int source_rank,
- const std::vector<int>& dev_per_task,
- CollectiveParams* cp);
-
const DeviceMgr* dev_mgr_;
DeviceResolverInterface* dev_resolver_; // Not owned.
string task_name_;
@@ -237,4 +230,4 @@ class CollectiveParamResolverLocal : public ParamResolverInterface {
} // namespace tensorflow
-#endif // TENSORFLOW_COMMON_RUNTIME_COLLECTIVE_PARAM_RESOLVER_LOCAL_H_
+#endif // TENSORFLOW_CORE_COMMON_RUNTIME_COLLECTIVE_PARAM_RESOLVER_LOCAL_H_