aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/keras/utils/multi_gpu_utils.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/keras/utils/multi_gpu_utils.py')
-rw-r--r--tensorflow/python/keras/utils/multi_gpu_utils.py17
1 files changed, 16 insertions, 1 deletions
diff --git a/tensorflow/python/keras/utils/multi_gpu_utils.py b/tensorflow/python/keras/utils/multi_gpu_utils.py
index e1c49bc852..04b2ea8fe3 100644
--- a/tensorflow/python/keras/utils/multi_gpu_utils.py
+++ b/tensorflow/python/keras/utils/multi_gpu_utils.py
@@ -244,9 +244,24 @@ def multi_gpu_model(model, gpus, cpu_merge=True, cpu_relocation=False):
for o in range(len(outputs)):
all_outputs[o].append(outputs[o])
+ # Deduplicate output names to handle Siamese networks.
+ occurrences = {}
+ for n in model.output_names:
+ if n not in occurrences:
+ occurrences[n] = 1
+ else:
+ occurrences[n] += 1
+ conflict_counter = {n: 0 for n, count in occurrences.items() if count > 1}
+ output_names = []
+ for n in model.output_names:
+ if n in conflict_counter:
+ conflict_counter[n] += 1
+ n += '_%d' % conflict_counter[n]
+ output_names.append(n)
+
# Merge outputs under expected scope.
with ops.device('/cpu:0' if cpu_merge else '/gpu:%d' % target_gpu_ids[0]):
merged = []
- for name, outputs in zip(model.output_names, all_outputs):
+ for name, outputs in zip(output_names, all_outputs):
merged.append(concatenate(outputs, axis=0, name=name))
return Model(model.inputs, merged)