diff options
Diffstat (limited to 'tensorflow/python/keras/utils/multi_gpu_utils_test.py')
-rw-r--r-- | tensorflow/python/keras/utils/multi_gpu_utils_test.py | 26 |
1 files changed, 26 insertions, 0 deletions
diff --git a/tensorflow/python/keras/utils/multi_gpu_utils_test.py b/tensorflow/python/keras/utils/multi_gpu_utils_test.py index 3d0351a11f..1780ab6587 100644 --- a/tensorflow/python/keras/utils/multi_gpu_utils_test.py +++ b/tensorflow/python/keras/utils/multi_gpu_utils_test.py @@ -198,5 +198,31 @@ class TestMultiGPUModel(test.TestCase): parallel_model.compile(loss='mean_squared_error', optimizer='adam') parallel_model.train_on_batch(x, y) + def test_multi_gpu_with_siamese_network(self): + gpus = 2 + + if not check_if_compatible_devices(gpus=gpus): + return + + with self.cached_session(): + input_shape = (3,) + nested_model = keras.models.Sequential([ + keras.layers.Dense(32, input_shape=input_shape), + keras.layers.Dense(1) + ], name='nested') + + input1 = keras.Input(input_shape) + input2 = keras.Input(input_shape) + score1 = nested_model(input1) + score2 = nested_model(input2) + score_sum = keras.layers.Add(name='add')([score1, score2]) + + siamese = keras.models.Model(inputs=[input1, input2], + outputs=[score_sum, score1, score2], + name='siamese') + parallel_siamese = keras.utils.multi_gpu_model(siamese, gpus) + self.assertEqual(parallel_siamese.output_names, + ['add', 'nested_1', 'nested_2']) + if __name__ == '__main__': test.main() |