 tensorflow/contrib/all_reduce/python/all_reduce.py | 28 ++++++++++++++--------------
 1 file changed, 14 insertions(+), 14 deletions(-)
diff --git a/tensorflow/contrib/all_reduce/python/all_reduce.py b/tensorflow/contrib/all_reduce/python/all_reduce.py
index 8add2aacff..159d985db5 100644
--- a/tensorflow/contrib/all_reduce/python/all_reduce.py
+++ b/tensorflow/contrib/all_reduce/python/all_reduce.py
@@ -18,10 +18,11 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import collections
 import math
-import re
 
 from tensorflow.contrib import nccl
+from tensorflow.python.framework import device as device_lib
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import math_ops
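
For reference, a minimal sketch (not part of the commit) of what the newly imported device_lib.DeviceSpec parser yields for a TensorFlow device string; the device string below is purely illustrative:

from tensorflow.python.framework import device as device_lib

# Illustrative device string; the job/replica/task values are made up.
spec = device_lib.DeviceSpec.from_string("/job:worker/replica:0/task:1/device:GPU:0")
print(spec.job)      # "worker"
print(spec.replica)  # 0
print(spec.task)     # 1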
@@ -659,21 +660,20 @@ def _split_by_task(devices, values):
   num_devices = len(devices)
   if num_devices != len(values):
     raise ValueError("len(devices) must equal len(values)")
-  pattern = re.compile(r"/task:(\d+)/")
-  per_task_devices = []
-  per_task_values = []
+  per_task_devices = collections.OrderedDict()
+  per_task_values = collections.OrderedDict()
   for d in range(num_devices):
-    m = pattern.search(devices[d])
-    if m:
-      index = int(m.group(1))
-      while index >= len(per_task_devices):
-        per_task_devices.append([])
-        per_task_values.append([])
-      per_task_devices[index].append(devices[d])
-      per_task_values[index].append(values[d])
-    else:
+    d_spec = device_lib.DeviceSpec.from_string(devices[d])
+    if not hasattr(d_spec, "task") or d_spec.task is None:
       assert False, "failed to parse device %s" % devices[d]
-  return (per_task_devices, per_task_values)
+    index = (d_spec.job or "localhost", d_spec.replica or 0, d_spec.task)
+    if index not in per_task_devices:
+      per_task_devices[index] = []
+      per_task_values[index] = []
+    per_task_devices[index].append(devices[d])
+    per_task_values[index].append(values[d])
+
+  return (list(per_task_devices.values()), list(per_task_values.values()))
 
 
 def build_nccl_all_reduce(input_tensors, red_op, un_op=None):
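
A minimal standalone sketch (not from the commit) of how grouping by the (job, replica, task) key behaves; split_by_task_sketch and the device strings below are hypothetical and only mirror the OrderedDict logic in the hunk above:

import collections

from tensorflow.python.framework import device as device_lib


def split_by_task_sketch(devices, values):
  """Illustrative copy of the new grouping logic, outside the module."""
  per_task_devices = collections.OrderedDict()
  per_task_values = collections.OrderedDict()
  for dev, val in zip(devices, values):
    spec = device_lib.DeviceSpec.from_string(dev)
    key = (spec.job or "localhost", spec.replica or 0, spec.task)
    per_task_devices.setdefault(key, []).append(dev)
    per_task_values.setdefault(key, []).append(val)
  return list(per_task_devices.values()), list(per_task_values.values())


# Devices on the same (job, replica, task) land in one group; task:0 on a
# different job would not collide with task:0 on "worker", unlike the old
# regex that keyed on the task number alone.
devices = ["/job:worker/task:0/device:GPU:0",
           "/job:worker/task:0/device:GPU:1",
           "/job:worker/task:1/device:GPU:0"]
values = ["v0", "v1", "v2"]
dev_groups, val_groups = split_by_task_sketch(devices, values)
# dev_groups == [devices[0:2], devices[2:]]; val_groups == [["v0", "v1"], ["v2"]]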