aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/keras/utils/multi_gpu_utils_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/keras/utils/multi_gpu_utils_test.py')
-rw-r--r--tensorflow/python/keras/utils/multi_gpu_utils_test.py26
1 files changed, 26 insertions, 0 deletions
diff --git a/tensorflow/python/keras/utils/multi_gpu_utils_test.py b/tensorflow/python/keras/utils/multi_gpu_utils_test.py
index 3d0351a11f..1780ab6587 100644
--- a/tensorflow/python/keras/utils/multi_gpu_utils_test.py
+++ b/tensorflow/python/keras/utils/multi_gpu_utils_test.py
@@ -198,5 +198,31 @@ class TestMultiGPUModel(test.TestCase):
parallel_model.compile(loss='mean_squared_error', optimizer='adam')
parallel_model.train_on_batch(x, y)
+ def test_multi_gpu_with_siamese_network(self):
+ gpus = 2
+
+ if not check_if_compatible_devices(gpus=gpus):
+ return
+
+ with self.cached_session():
+ input_shape = (3,)
+ nested_model = keras.models.Sequential([
+ keras.layers.Dense(32, input_shape=input_shape),
+ keras.layers.Dense(1)
+ ], name='nested')
+
+ input1 = keras.Input(input_shape)
+ input2 = keras.Input(input_shape)
+ score1 = nested_model(input1)
+ score2 = nested_model(input2)
+ score_sum = keras.layers.Add(name='add')([score1, score2])
+
+ siamese = keras.models.Model(inputs=[input1, input2],
+ outputs=[score_sum, score1, score2],
+ name='siamese')
+ parallel_siamese = keras.utils.multi_gpu_model(siamese, gpus)
+ self.assertEqual(parallel_siamese.output_names,
+ ['add', 'nested_1', 'nested_2'])
+
if __name__ == '__main__':
test.main()