diff options
Diffstat (limited to 'tensorflow/core/common_runtime/collective_param_resolver_local.h')
-rw-r--r-- | tensorflow/core/common_runtime/collective_param_resolver_local.h | 23 |
1 files changed, 8 insertions, 15 deletions
diff --git a/tensorflow/core/common_runtime/collective_param_resolver_local.h b/tensorflow/core/common_runtime/collective_param_resolver_local.h index 2e2aa801d9..c5c3497e28 100644 --- a/tensorflow/core/common_runtime/collective_param_resolver_local.h +++ b/tensorflow/core/common_runtime/collective_param_resolver_local.h @@ -12,10 +12,14 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMMON_RUNTIME_COLLECTIVE_PARAM_RESOLVER_LOCAL_H_ -#define TENSORFLOW_COMMON_RUNTIME_COLLECTIVE_PARAM_RESOLVER_LOCAL_H_ +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_COLLECTIVE_PARAM_RESOLVER_LOCAL_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_COLLECTIVE_PARAM_RESOLVER_LOCAL_H_ +#include <functional> +#include <memory> +#include <set> #include <string> +#include <vector> #include "tensorflow/core/framework/collective.h" #include "tensorflow/core/lib/gtl/flatmap.h" @@ -79,6 +83,7 @@ class CollectiveParamResolverLocal : public ParamResolverInterface { // Used to complete/verify CollInstance. struct InstanceRec; + typedef std::function<void(InstanceRec*)> IRConsumer; struct InstanceRec { // This structure has two mutexes so that a possibly long @@ -212,18 +217,6 @@ class CollectiveParamResolverLocal : public ParamResolverInterface { void CallbackWithStatus(const InstanceRecCallback& done, InstanceRec* irec) LOCKS_EXCLUDED(irec->out_mu); - friend class CollectiveParamResolverLocalTest; - // Establishes the requested number of subdivision permutations based on the - // ring order implicit in the device order. - static void GenerateSubdivPerms(const string& device, int source_rank, - CollectiveParams* cp); - // Establishes the subdivisions for broadcast op. The first subdiv executes - // binary tree bcast with one device per task. Each subsequent subdiv - // executes intra-task binary tree broadcast. - static void GenerateBcastSubdivPerms(const string& device, int source_rank, - const std::vector<int>& dev_per_task, - CollectiveParams* cp); - const DeviceMgr* dev_mgr_; DeviceResolverInterface* dev_resolver_; // Not owned. string task_name_; @@ -237,4 +230,4 @@ class CollectiveParamResolverLocal : public ParamResolverInterface { } // namespace tensorflow -#endif // TENSORFLOW_COMMON_RUNTIME_COLLECTIVE_PARAM_RESOLVER_LOCAL_H_ +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_COLLECTIVE_PARAM_RESOLVER_LOCAL_H_ |