diff options
Diffstat (limited to 'tensorflow/python/keras/utils/multi_gpu_utils.py')
-rw-r--r-- | tensorflow/python/keras/utils/multi_gpu_utils.py | 17 |
1 files changed, 16 insertions, 1 deletions
diff --git a/tensorflow/python/keras/utils/multi_gpu_utils.py b/tensorflow/python/keras/utils/multi_gpu_utils.py index e1c49bc852..04b2ea8fe3 100644 --- a/tensorflow/python/keras/utils/multi_gpu_utils.py +++ b/tensorflow/python/keras/utils/multi_gpu_utils.py @@ -244,9 +244,24 @@ def multi_gpu_model(model, gpus, cpu_merge=True, cpu_relocation=False): for o in range(len(outputs)): all_outputs[o].append(outputs[o]) + # Deduplicate output names to handle Siamese networks. + occurrences = {} + for n in model.output_names: + if n not in occurrences: + occurrences[n] = 1 + else: + occurrences[n] += 1 + conflict_counter = {n: 0 for n, count in occurrences.items() if count > 1} + output_names = [] + for n in model.output_names: + if n in conflict_counter: + conflict_counter[n] += 1 + n += '_%d' % conflict_counter[n] + output_names.append(n) + # Merge outputs under expected scope. with ops.device('/cpu:0' if cpu_merge else '/gpu:%d' % target_gpu_ids[0]): merged = [] - for name, outputs in zip(model.output_names, all_outputs): + for name, outputs in zip(output_names, all_outputs): merged.append(concatenate(outputs, axis=0, name=name)) return Model(model.inputs, merged) |